import torch
from torch.utils.data import Dataset


class Covid19Dataset(Dataset):
    """Map-style dataset of float32 features with optional float32 labels.

    When labels are supplied, each item is a ``(features, label)`` pair.
    When labels are omitted (e.g. for a prediction/test split), each item is
    just the feature row.

    Args:
        feats: Features, anything ``torch.tensor`` accepts (nested lists,
            array-likes) or an existing ``torch.Tensor``. Indexed along its
            first dimension.
        labels: Optional labels in the same accepted forms. Must have the same
            first-dimension length as ``feats``.

    Raises:
        ValueError: If ``labels`` is given and its length differs from the
            length of ``feats``.
    """

    def __init__(self, feats, labels=None):
        self.x = self._to_float_tensor(feats)
        self.y = None if labels is None else self._to_float_tensor(labels)
        # Catch mismatches up front; otherwise they surface later as an
        # IndexError inside a DataLoader or as silently ignored labels.
        if self.y is not None and len(self.y) != len(self.x):
            raise ValueError(
                f"feats and labels must have the same length, "
                f"got {len(self.x)} and {len(self.y)}"
            )

    @staticmethod
    def _to_float_tensor(data):
        """Convert ``data`` to a float32 tensor.

        Non-tensor inputs are copied by ``torch.tensor``. Tensor inputs are
        detached and cast. ``.to`` returns the same tensor if it is already
        float32. Unlike the legacy ``torch.FloatTensor`` constructor, this
        accepts tensors of any dtype.
        """
        if isinstance(data, torch.Tensor):
            return data.detach().to(torch.float32)
        return torch.tensor(data, dtype=torch.float32)

    def __getitem__(self, index):
        """Return ``x[index]``, or ``(x[index], y[index])`` when labels exist."""
        if self.y is None:
            return self.x[index]
        return self.x[index], self.y[index]

    def __len__(self):
        """Return the number of samples (first dimension of the features)."""
        return len(self.x)
